AdamW operator (Fixing Weight Decay Regularization in Adam) #13728
Conversation
@sxjscience @szhengac could you guys help review this PR?
python/mxnet/optimizer/optimizer.py
Outdated
rescaled_grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w)
According to the paper, there are two learning rates: an alpha before m / (sqrt(v) + epsilon).
Good point. The issue is that the learning rate and the schedule multiplier are not decoupled in MXNet. Here learning_rate is effectively eta_t * alpha in the paper, and wd actually needs to be set as wd / alpha. In other words, wd can be rescaled so that it does exactly the same thing as in the paper. Would this be acceptable? If so, maybe I can move this to contrib for the moment.
I think it's acceptable as long as the wd is set correctly.
On second thought, I think it's better to keep it consistent with the paper.
Thanks.
Can you please provide the output of an end-to-end use case using the AdamW optimizer?
@sandeep-krishnamurthy training/fine-tuning the BERT model in GluonNLP would be a use case of AdamW.
python/mxnet/optimizer/optimizer.py
Outdated
kwargs['clip_gradient'] = self.clip_gradient

mean, var = state
adamw_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
Should we set wd to something like wd / self._original_lr?
@sxjscience @szhengac I took a step back, moved the operator to contrib, and used the same notation as in the paper. I think the optimizer API still needs more discussion, so I removed it from the PR.
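A hedged usage sketch of the contrib operator (not from the PR itself): it assumes the op is exposed as mx.nd.contrib.adamw_update and takes both lr (alpha in the paper) and a separate schedule multiplier eta; the exact argument list, in particular how rescale_grad is passed, should be checked against the merged operator.

import mxnet as mx

weight = mx.nd.random.normal(shape=(4,))
grad = mx.nd.random.normal(shape=(4,))
mean = mx.nd.zeros_like(weight)   # first moment (m)
var = mx.nd.zeros_like(weight)    # second moment (v)

# one AdamW step, writing the updated weight back in place
mx.nd.contrib.adamw_update(weight, grad, mean, var, out=weight,
                           lr=1e-3, eta=1.0, wd=0.01,
                           beta1=0.9, beta2=0.999, epsilon=1e-8)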
…3728) * tests * remove optimizer and move op to contrib * rename parameter
Description
Implement the modification of Adam proposed in "Fixing Weight Decay Regularization in Adam" (https://arxiv.org/abs/1711.05101).
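For reference, a sketch of the update rule written with the paper's notation as discussed in the review above (eta is the schedule multiplier, learning_rate plays the role of alpha, and bias correction is omitted as in the quoted docstring); the exact docstring should be checked against the merged contrib operator:

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - eta * (learning_rate * m / (sqrt(v) + epsilon) + wd * w)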
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.